{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from cot import Collection\n",
    "from cot.stats import evaluation_as_table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "coll = Collection.from_json(\"./chatgpt_hard_dataset.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7a0fc2da14b242f7a0faf01c16c7edbc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0772e1e835bb40009e63ff80ac1a3d18",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0f8763e16b8c4027baf907388fd49ecb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "42f3f63c5a8f494ea28f4ed19c16b0c1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "49fe3b9ff3a24204834956cb403254b7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "| Name           | Train   | Valid   | Test   |\n",
       "|----------------|---------|---------|--------|\n",
       "| commonsense_qa | -       | 1       | -      |\n",
       "| med_qa         | -       | -       | 1      |\n",
       "| open_book_qa   | -       | -       | 1      |\n",
       "| strategy_qa    | 1       | -       | -      |\n",
       "| worldtree      | -       | -       | 1      |\n",
       "\n",
       "Not loaded: ['aqua', 'asdiv', 'entailment_bank', 'gsm8k', 'mawps', 'medmc_qa', 'pubmed_qa', 'qed', 'svamp']"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "coll = coll.select(\"all\", 1)\n",
    "coll.delete_all_generated_cots()\n",
    "coll"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[nltk_data] Downloading package punkt to /home/kon/nltk_data...\n",
      "[nltk_data]   Package punkt is already up-to-date!\n"
     ]
    }
   ],
   "source": [
    "coll = Collection.from_json(\"/home/kon/work/ThoughtSource/notebooks/chatgpt_hard_dataset_qa-10_to_qa-18.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration of the input and parameters of the language model \n",
    "config={\n",
    "    \"instruction_keys\": \"qa-19\",\n",
    "    \"cot_trigger_keys\": None,\n",
    "    \"answer_extraction_keys\": 'auto-kojima', \n",
    "    \"author\" : \"thoughtsource\",\n",
    "    \"api_service\": \"openai_chat\",\n",
    "    \"api_time_interval\": 1,\n",
    "    \"engine\": \"gpt-3.5-turbo\", \n",
    "    \"temperature\": 0,\n",
    "    \"max_tokens\": 512,\n",
    "    \"verbose\": False,\n",
    "    \"warn\": False,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration of the input and parameters of the language model \n",
    "config={\n",
    "    \"instruction_keys\": None,\n",
    "    \"cot_trigger_keys\": None,\n",
    "    \"template_cot_generation\":\n",
    "    \"\"\"\n",
    "    Instruction: \n",
    "    (1) Let's work this out in a step by step way to be sure we have the right answer.\n",
    "    (2) Apply all those questions to your reasoning.\n",
    "    Does the response adequately address the request?\n",
    "    Is the response plausible?\n",
    "    Does the response contain any irrelevant or incorrect information?\n",
    "    Does the response exhibit reading comprehension errors, such as misinterpreting or misrepresenting information?\n",
    "    Are there any unnecessary repetitions in the response?\n",
    "    Does the response omit any steps in the process of reasoning?\n",
    "    Is the reasoning logically valid, consistent and coherent?\n",
    "    Are statements in the response made with appropriate levels of confidence?\n",
    "    Is the reasoning free of cognitive biases or fallacies?\n",
    "    Does the reasoning resemble that of a psychologically balanced, calm and reasonable person?\n",
    "    If the response contains formal reasoning (e.g., math, computer code), were any errors made?\n",
    "    If external tools (e.g., search engines, APIs, mathematical/statistical tools) were used in the response, were they used correctly?\n",
    "    If the response contains step-by-step reasoning, could better reasoning steps or sub-questions have been chosen?\n",
    "    \\\\n\\\\n{question}\\\\n{answer_choices}\\\\n\\\\n{cot_trigger}\n",
    "    \"\"\",\n",
    "    \"answer_extraction_keys\": 'auto-kojima', \n",
    "    \"author\" : \"thoughtsource\",\n",
    "    \"api_service\": \"openai_chat\",\n",
    "    \"api_time_interval\": 1,\n",
    "    \"engine\": \"gpt-3.5-turbo\", \n",
    "    \"temperature\": 0,\n",
    "    \"max_tokens\": 512,\n",
    "    \"verbose\": False,\n",
    "    \"warn\": False,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration of the input and parameters of the language model \n",
    "config={\n",
    "    \"instruction_keys\": None,\n",
    "    \"cot_trigger_keys\": None,\n",
    "    \"template_cot_generation\":\n",
    "    \"\"\"\n",
    "    Instruction: \n",
    "    (1) Let's work this out in a step by step way to be sure we have the right answer.\n",
    "    (2) Apply all of those questions to your reasoning:\n",
    "    What is the goal of this request?\n",
    "    What level of detail and complexity is expected in the response?\n",
    "    What level of certainty is expected in the response?\n",
    "    Is the response expected to be accompanied by supporting evidence or references?\n",
    "    Is responding to the request straightforward or does it require step-by-step reasoning or raising sub-questions?\n",
    "    If step-by-step reasoning or sub-questions are needed, which reasoning steps should be taken, or which sub-questions should be raised?\n",
    "    What resources and methods are available to respond to the request?\n",
    "    \\\\n\\\\n{question}\\\\n{answer_choices}\\\\n\\\\n{cot_trigger}\n",
    "    \"\"\",\n",
    "    \"answer_extraction_keys\": 'auto-kojima', \n",
    "    \"author\" : \"thoughtsource\",\n",
    "    \"api_service\": \"openai_chat\",\n",
    "    \"api_time_interval\": 1,\n",
    "    \"engine\": \"gpt-3.5-turbo\", \n",
    "    \"temperature\": 0,\n",
    "    \"max_tokens\": 512,\n",
    "    \"verbose\": False,\n",
    "    \"warn\": False,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating commonsense_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "befbc4b5e10f4ff0b0e6f9a12f7ee444",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating med_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b896afca58154946b2819f8963661727",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating open_book_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "846ef418fcc04f368e8217bef2935164",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating strategy_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "38561f56f2604bada6f383b6a33d2baf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating worldtree...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2153669bad324f5da7f8463468aac415",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "coll.generate(config=config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8b0921f55ebf4afcbbb82dba4f04dcce",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3de540ad40fc44f6962277ef46deb82d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2537ee0c4ddb4ada861391d570dc2451",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d407271f9dc347f7a1a5f3d4bd9d4e9f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9f9489e69c4c4e6f8dc782a79b5ad271",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<style type=\"text/css\">\n",
       "#T_d9760_row0_col0, #T_d9760_row1_col1, #T_d9760_row2_col2, #T_d9760_row3_col0, #T_d9760_row4_col0, #T_d9760_row5_col1 {\n",
       "  font-weight: bold;\n",
       "}\n",
       "</style>\n",
       "<table id=\"T_d9760\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th class=\"blank level0\" >&nbsp;</th>\n",
       "      <th id=\"T_d9760_level0_col0\" class=\"col_heading level0 col0\" >qa-10_None</th>\n",
       "      <th id=\"T_d9760_level0_col1\" class=\"col_heading level0 col1\" >qa-11_None</th>\n",
       "      <th id=\"T_d9760_level0_col2\" class=\"col_heading level0 col2\" >qa-12_None</th>\n",
       "      <th id=\"T_d9760_level0_col3\" class=\"col_heading level0 col3\" >qa-13_None</th>\n",
       "      <th id=\"T_d9760_level0_col4\" class=\"col_heading level0 col4\" >qa-14_None</th>\n",
       "      <th id=\"T_d9760_level0_col5\" class=\"col_heading level0 col5\" >qa-15_None</th>\n",
       "      <th id=\"T_d9760_level0_col6\" class=\"col_heading level0 col6\" >qa-16_None</th>\n",
       "      <th id=\"T_d9760_level0_col7\" class=\"col_heading level0 col7\" >qa-17_None</th>\n",
       "      <th id=\"T_d9760_level0_col8\" class=\"col_heading level0 col8\" >qa-18_None</th>\n",
       "      <th id=\"T_d9760_level0_col9\" class=\"col_heading level0 col9\" >qa-19_None</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th class=\"blank level1\" >&nbsp;</th>\n",
       "      <th id=\"T_d9760_level1_col0\" class=\"col_heading level1 col0\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_d9760_level1_col1\" class=\"col_heading level1 col1\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_d9760_level1_col2\" class=\"col_heading level1 col2\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_d9760_level1_col3\" class=\"col_heading level1 col3\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_d9760_level1_col4\" class=\"col_heading level1 col4\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_d9760_level1_col5\" class=\"col_heading level1 col5\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_d9760_level1_col6\" class=\"col_heading level1 col6\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_d9760_level1_col7\" class=\"col_heading level1 col7\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_d9760_level1_col8\" class=\"col_heading level1 col8\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_d9760_level1_col9\" class=\"col_heading level1 col9\" >gpt-3.5-turbo</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th id=\"T_d9760_level0_row0\" class=\"row_heading level0 row0\" >commonsense_qa</th>\n",
       "      <td id=\"T_d9760_row0_col0\" class=\"data row0 col0\" >1.00</td>\n",
       "      <td id=\"T_d9760_row0_col1\" class=\"data row0 col1\" >1.00</td>\n",
       "      <td id=\"T_d9760_row0_col2\" class=\"data row0 col2\" >1.00</td>\n",
       "      <td id=\"T_d9760_row0_col3\" class=\"data row0 col3\" >1.00</td>\n",
       "      <td id=\"T_d9760_row0_col4\" class=\"data row0 col4\" >0.00</td>\n",
       "      <td id=\"T_d9760_row0_col5\" class=\"data row0 col5\" >0.00</td>\n",
       "      <td id=\"T_d9760_row0_col6\" class=\"data row0 col6\" >1.00</td>\n",
       "      <td id=\"T_d9760_row0_col7\" class=\"data row0 col7\" >1.00</td>\n",
       "      <td id=\"T_d9760_row0_col8\" class=\"data row0 col8\" >1.00</td>\n",
       "      <td id=\"T_d9760_row0_col9\" class=\"data row0 col9\" >0.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_d9760_level0_row1\" class=\"row_heading level0 row1\" >med_qa</th>\n",
       "      <td id=\"T_d9760_row1_col0\" class=\"data row1 col0\" >0.00</td>\n",
       "      <td id=\"T_d9760_row1_col1\" class=\"data row1 col1\" >1.00</td>\n",
       "      <td id=\"T_d9760_row1_col2\" class=\"data row1 col2\" >0.00</td>\n",
       "      <td id=\"T_d9760_row1_col3\" class=\"data row1 col3\" >0.00</td>\n",
       "      <td id=\"T_d9760_row1_col4\" class=\"data row1 col4\" >0.00</td>\n",
       "      <td id=\"T_d9760_row1_col5\" class=\"data row1 col5\" >0.00</td>\n",
       "      <td id=\"T_d9760_row1_col6\" class=\"data row1 col6\" >0.00</td>\n",
       "      <td id=\"T_d9760_row1_col7\" class=\"data row1 col7\" >0.00</td>\n",
       "      <td id=\"T_d9760_row1_col8\" class=\"data row1 col8\" >0.00</td>\n",
       "      <td id=\"T_d9760_row1_col9\" class=\"data row1 col9\" >0.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_d9760_level0_row2\" class=\"row_heading level0 row2\" >open_book_qa</th>\n",
       "      <td id=\"T_d9760_row2_col0\" class=\"data row2 col0\" >0.00</td>\n",
       "      <td id=\"T_d9760_row2_col1\" class=\"data row2 col1\" >0.00</td>\n",
       "      <td id=\"T_d9760_row2_col2\" class=\"data row2 col2\" >1.00</td>\n",
       "      <td id=\"T_d9760_row2_col3\" class=\"data row2 col3\" >1.00</td>\n",
       "      <td id=\"T_d9760_row2_col4\" class=\"data row2 col4\" >0.00</td>\n",
       "      <td id=\"T_d9760_row2_col5\" class=\"data row2 col5\" >1.00</td>\n",
       "      <td id=\"T_d9760_row2_col6\" class=\"data row2 col6\" >0.00</td>\n",
       "      <td id=\"T_d9760_row2_col7\" class=\"data row2 col7\" >0.00</td>\n",
       "      <td id=\"T_d9760_row2_col8\" class=\"data row2 col8\" >0.00</td>\n",
       "      <td id=\"T_d9760_row2_col9\" class=\"data row2 col9\" >1.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_d9760_level0_row3\" class=\"row_heading level0 row3\" >strategy_qa</th>\n",
       "      <td id=\"T_d9760_row3_col0\" class=\"data row3 col0\" >0.00</td>\n",
       "      <td id=\"T_d9760_row3_col1\" class=\"data row3 col1\" >0.00</td>\n",
       "      <td id=\"T_d9760_row3_col2\" class=\"data row3 col2\" >0.00</td>\n",
       "      <td id=\"T_d9760_row3_col3\" class=\"data row3 col3\" >0.00</td>\n",
       "      <td id=\"T_d9760_row3_col4\" class=\"data row3 col4\" >0.00</td>\n",
       "      <td id=\"T_d9760_row3_col5\" class=\"data row3 col5\" >0.00</td>\n",
       "      <td id=\"T_d9760_row3_col6\" class=\"data row3 col6\" >0.00</td>\n",
       "      <td id=\"T_d9760_row3_col7\" class=\"data row3 col7\" >0.00</td>\n",
       "      <td id=\"T_d9760_row3_col8\" class=\"data row3 col8\" >0.00</td>\n",
       "      <td id=\"T_d9760_row3_col9\" class=\"data row3 col9\" >0.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_d9760_level0_row4\" class=\"row_heading level0 row4\" >worldtree</th>\n",
       "      <td id=\"T_d9760_row4_col0\" class=\"data row4 col0\" >1.00</td>\n",
       "      <td id=\"T_d9760_row4_col1\" class=\"data row4 col1\" >1.00</td>\n",
       "      <td id=\"T_d9760_row4_col2\" class=\"data row4 col2\" >0.00</td>\n",
       "      <td id=\"T_d9760_row4_col3\" class=\"data row4 col3\" >1.00</td>\n",
       "      <td id=\"T_d9760_row4_col4\" class=\"data row4 col4\" >0.00</td>\n",
       "      <td id=\"T_d9760_row4_col5\" class=\"data row4 col5\" >1.00</td>\n",
       "      <td id=\"T_d9760_row4_col6\" class=\"data row4 col6\" >1.00</td>\n",
       "      <td id=\"T_d9760_row4_col7\" class=\"data row4 col7\" >1.00</td>\n",
       "      <td id=\"T_d9760_row4_col8\" class=\"data row4 col8\" >0.00</td>\n",
       "      <td id=\"T_d9760_row4_col9\" class=\"data row4 col9\" >0.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_d9760_level0_row5\" class=\"row_heading level0 row5\" >Average</th>\n",
       "      <td id=\"T_d9760_row5_col0\" class=\"data row5 col0\" >0.40</td>\n",
       "      <td id=\"T_d9760_row5_col1\" class=\"data row5 col1\" >0.60</td>\n",
       "      <td id=\"T_d9760_row5_col2\" class=\"data row5 col2\" >0.40</td>\n",
       "      <td id=\"T_d9760_row5_col3\" class=\"data row5 col3\" >0.60</td>\n",
       "      <td id=\"T_d9760_row5_col4\" class=\"data row5 col4\" >0.00</td>\n",
       "      <td id=\"T_d9760_row5_col5\" class=\"data row5 col5\" >0.40</td>\n",
       "      <td id=\"T_d9760_row5_col6\" class=\"data row5 col6\" >0.40</td>\n",
       "      <td id=\"T_d9760_row5_col7\" class=\"data row5 col7\" >0.40</td>\n",
       "      <td id=\"T_d9760_row5_col8\" class=\"data row5 col8\" >0.20</td>\n",
       "      <td id=\"T_d9760_row5_col9\" class=\"data row5 col9\" >0.20</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<pandas.io.formats.style.Styler at 0x7fc1680c7460>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "eval = coll.evaluate()\n",
    "evaluation_as_table(eval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c473a818e0cd4d158fbf255393fb48da",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "78597bb1843f45aeb262bc049e8a8601",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "addfed3533c54ebfa7b6f7ede9917c95",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "462d1437510e4047a06a73fcc24c80e1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "54ebdf98a8404d4985ce99c633d8688b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "coll.dump(\"chatgpt_hard_dataset_qa-10_to_qa-19\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d0041816d94847a79b3a98c5bdbec8f9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "248e5d8a292e46d592a13a864ccc0749",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "828701c188c14a468e4a92853f5eac11",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "02a83d35c17c49258c3e1e2c94ceee73",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f85977154de743ca88235c58e028c23f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "coll.dump(\"/home/kon/work/ThoughtSource/notebooks/chatgpt_hard_dataset_qa-01_to_qa-09_zhou-01-ins_zhou-01.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "coll = Collection.from_json(\"/home/kon/work/ThoughtSource/notebooks/chatgpt_hard_dataset_qa-01_to_qa-09_zhou-01-ins.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# merge with original dataset\n",
    "coll_1 = Collection.from_json(\"./chatgpt_hard_dataset.json\")\n",
    "coll_1 = coll_1.select(\"all\", 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7d54e2dd4c7349aca13b1f6cb53df6b9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3abe641505444aecb25811f186d1aed9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "748a53ed8b9149158169c39fe6674c47",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e6a4fa54b8d54e4fa363a44661f8c631",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a666c347a5fe40c4853848fea8795a7b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c5045911f67347bca523866c5c6966ce",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "980edba6600947538783900ff3c122ba",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4d9b361dfbf34f8492ce313a6630ca4f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7963b5d3486b4c9a865892f642de5b9a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "850619ffe15e4a0da69cc46671e41ff6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "coll = coll.merge(coll_1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[nltk_data] Downloading package punkt to /home/kon/nltk_data...\n",
      "[nltk_data]   Package punkt is already up-to-date!\n"
     ]
    }
   ],
   "source": [
    "ts_100 = Collection.load_thoughtsource_100()\n",
    "ts_100.select_generated_cots(author=\"thoughtsource\", model = \"gpt-3.5-turbo\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style type=\"text/css\">\n",
       "#T_4e33d_row0_col0, #T_4e33d_row1_col4, #T_4e33d_row2_col0, #T_4e33d_row3_col4, #T_4e33d_row4_col11, #T_4e33d_row5_col0, #T_4e33d_row6_col0 {\n",
       "  font-weight: bold;\n",
       "}\n",
       "</style>\n",
       "<table id=\"T_4e33d\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th class=\"blank level0\" >&nbsp;</th>\n",
       "      <th id=\"T_4e33d_level0_col0\" class=\"col_heading level0 col0\" >None_None</th>\n",
       "      <th id=\"T_4e33d_level0_col1\" class=\"col_heading level0 col1\" >None_kojima-01</th>\n",
       "      <th id=\"T_4e33d_level0_col2\" class=\"col_heading level0 col2\" >None_kojima-03</th>\n",
       "      <th id=\"T_4e33d_level0_col3\" class=\"col_heading level0 col3\" >None_kojima-09</th>\n",
       "      <th id=\"T_4e33d_level0_col4\" class=\"col_heading level0 col4\" >None_zhou-01</th>\n",
       "      <th id=\"T_4e33d_level0_col5\" class=\"col_heading level0 col5\" >qa-01_None</th>\n",
       "      <th id=\"T_4e33d_level0_col6\" class=\"col_heading level0 col6\" >qa-05_None</th>\n",
       "      <th id=\"T_4e33d_level0_col7\" class=\"col_heading level0 col7\" >qa-08_None</th>\n",
       "      <th id=\"T_4e33d_level0_col8\" class=\"col_heading level0 col8\" >qa-09_None</th>\n",
       "      <th id=\"T_4e33d_level0_col9\" class=\"col_heading level0 col9\" >qa-10_None</th>\n",
       "      <th id=\"T_4e33d_level0_col10\" class=\"col_heading level0 col10\" >qa-12_None</th>\n",
       "      <th id=\"T_4e33d_level0_col11\" class=\"col_heading level0 col11\" >qa-13_None</th>\n",
       "      <th id=\"T_4e33d_level0_col12\" class=\"col_heading level0 col12\" >qa-16_None</th>\n",
       "      <th id=\"T_4e33d_level0_col13\" class=\"col_heading level0 col13\" >qa-17_None</th>\n",
       "      <th id=\"T_4e33d_level0_col14\" class=\"col_heading level0 col14\" >zhou-01-ins_None</th>\n",
       "      <th id=\"T_4e33d_level0_col15\" class=\"col_heading level0 col15\" >zhou-01-ins_zhou-01</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th class=\"blank level1\" >&nbsp;</th>\n",
       "      <th id=\"T_4e33d_level1_col0\" class=\"col_heading level1 col0\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_4e33d_level1_col1\" class=\"col_heading level1 col1\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_4e33d_level1_col2\" class=\"col_heading level1 col2\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_4e33d_level1_col3\" class=\"col_heading level1 col3\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_4e33d_level1_col4\" class=\"col_heading level1 col4\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_4e33d_level1_col5\" class=\"col_heading level1 col5\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_4e33d_level1_col6\" class=\"col_heading level1 col6\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_4e33d_level1_col7\" class=\"col_heading level1 col7\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_4e33d_level1_col8\" class=\"col_heading level1 col8\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_4e33d_level1_col9\" class=\"col_heading level1 col9\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_4e33d_level1_col10\" class=\"col_heading level1 col10\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_4e33d_level1_col11\" class=\"col_heading level1 col11\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_4e33d_level1_col12\" class=\"col_heading level1 col12\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_4e33d_level1_col13\" class=\"col_heading level1 col13\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_4e33d_level1_col14\" class=\"col_heading level1 col14\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_4e33d_level1_col15\" class=\"col_heading level1 col15\" >gpt-3.5-turbo</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th id=\"T_4e33d_level0_row0\" class=\"row_heading level0 row0\" >commonsense_qa</th>\n",
       "      <td id=\"T_4e33d_row0_col0\" class=\"data row0 col0\" >0.72</td>\n",
       "      <td id=\"T_4e33d_row0_col1\" class=\"data row0 col1\" >0.67</td>\n",
       "      <td id=\"T_4e33d_row0_col2\" class=\"data row0 col2\" >0.63</td>\n",
       "      <td id=\"T_4e33d_row0_col3\" class=\"data row0 col3\" >0.70</td>\n",
       "      <td id=\"T_4e33d_row0_col4\" class=\"data row0 col4\" >0.66</td>\n",
       "      <td id=\"T_4e33d_row0_col5\" class=\"data row0 col5\" >0.66</td>\n",
       "      <td id=\"T_4e33d_row0_col6\" class=\"data row0 col6\" >0.69</td>\n",
       "      <td id=\"T_4e33d_row0_col7\" class=\"data row0 col7\" >0.62</td>\n",
       "      <td id=\"T_4e33d_row0_col8\" class=\"data row0 col8\" >0.64</td>\n",
       "      <td id=\"T_4e33d_row0_col9\" class=\"data row0 col9\" >0.68</td>\n",
       "      <td id=\"T_4e33d_row0_col10\" class=\"data row0 col10\" >0.63</td>\n",
       "      <td id=\"T_4e33d_row0_col11\" class=\"data row0 col11\" >0.61</td>\n",
       "      <td id=\"T_4e33d_row0_col12\" class=\"data row0 col12\" >0.58</td>\n",
       "      <td id=\"T_4e33d_row0_col13\" class=\"data row0 col13\" >0.66</td>\n",
       "      <td id=\"T_4e33d_row0_col14\" class=\"data row0 col14\" >0.72</td>\n",
       "      <td id=\"T_4e33d_row0_col15\" class=\"data row0 col15\" >0.65</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_4e33d_level0_row1\" class=\"row_heading level0 row1\" >med_qa</th>\n",
       "      <td id=\"T_4e33d_row1_col0\" class=\"data row1 col0\" >0.58</td>\n",
       "      <td id=\"T_4e33d_row1_col1\" class=\"data row1 col1\" >0.59</td>\n",
       "      <td id=\"T_4e33d_row1_col2\" class=\"data row1 col2\" >0.59</td>\n",
       "      <td id=\"T_4e33d_row1_col3\" class=\"data row1 col3\" >0.51</td>\n",
       "      <td id=\"T_4e33d_row1_col4\" class=\"data row1 col4\" >0.65</td>\n",
       "      <td id=\"T_4e33d_row1_col5\" class=\"data row1 col5\" >0.54</td>\n",
       "      <td id=\"T_4e33d_row1_col6\" class=\"data row1 col6\" >0.53</td>\n",
       "      <td id=\"T_4e33d_row1_col7\" class=\"data row1 col7\" >0.46</td>\n",
       "      <td id=\"T_4e33d_row1_col8\" class=\"data row1 col8\" >0.55</td>\n",
       "      <td id=\"T_4e33d_row1_col9\" class=\"data row1 col9\" >0.56</td>\n",
       "      <td id=\"T_4e33d_row1_col10\" class=\"data row1 col10\" >0.59</td>\n",
       "      <td id=\"T_4e33d_row1_col11\" class=\"data row1 col11\" >0.49</td>\n",
       "      <td id=\"T_4e33d_row1_col12\" class=\"data row1 col12\" >0.60</td>\n",
       "      <td id=\"T_4e33d_row1_col13\" class=\"data row1 col13\" >0.56</td>\n",
       "      <td id=\"T_4e33d_row1_col14\" class=\"data row1 col14\" >0.54</td>\n",
       "      <td id=\"T_4e33d_row1_col15\" class=\"data row1 col15\" >0.52</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_4e33d_level0_row2\" class=\"row_heading level0 row2\" >medmc_qa</th>\n",
       "      <td id=\"T_4e33d_row2_col0\" class=\"data row2 col0\" >0.58</td>\n",
       "      <td id=\"T_4e33d_row2_col1\" class=\"data row2 col1\" >0.47</td>\n",
       "      <td id=\"T_4e33d_row2_col2\" class=\"data row2 col2\" >0.50</td>\n",
       "      <td id=\"T_4e33d_row2_col3\" class=\"data row2 col3\" >0.50</td>\n",
       "      <td id=\"T_4e33d_row2_col4\" class=\"data row2 col4\" >0.48</td>\n",
       "      <td id=\"T_4e33d_row2_col5\" class=\"data row2 col5\" >0.47</td>\n",
       "      <td id=\"T_4e33d_row2_col6\" class=\"data row2 col6\" >0.42</td>\n",
       "      <td id=\"T_4e33d_row2_col7\" class=\"data row2 col7\" >0.45</td>\n",
       "      <td id=\"T_4e33d_row2_col8\" class=\"data row2 col8\" >0.47</td>\n",
       "      <td id=\"T_4e33d_row2_col9\" class=\"data row2 col9\" >0.49</td>\n",
       "      <td id=\"T_4e33d_row2_col10\" class=\"data row2 col10\" >0.53</td>\n",
       "      <td id=\"T_4e33d_row2_col11\" class=\"data row2 col11\" >0.41</td>\n",
       "      <td id=\"T_4e33d_row2_col12\" class=\"data row2 col12\" >0.48</td>\n",
       "      <td id=\"T_4e33d_row2_col13\" class=\"data row2 col13\" >0.53</td>\n",
       "      <td id=\"T_4e33d_row2_col14\" class=\"data row2 col14\" >0.44</td>\n",
       "      <td id=\"T_4e33d_row2_col15\" class=\"data row2 col15\" >0.40</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_4e33d_level0_row3\" class=\"row_heading level0 row3\" >open_book_qa</th>\n",
       "      <td id=\"T_4e33d_row3_col0\" class=\"data row3 col0\" >0.77</td>\n",
       "      <td id=\"T_4e33d_row3_col1\" class=\"data row3 col1\" >0.77</td>\n",
       "      <td id=\"T_4e33d_row3_col2\" class=\"data row3 col2\" >0.73</td>\n",
       "      <td id=\"T_4e33d_row3_col3\" class=\"data row3 col3\" >0.73</td>\n",
       "      <td id=\"T_4e33d_row3_col4\" class=\"data row3 col4\" >0.81</td>\n",
       "      <td id=\"T_4e33d_row3_col5\" class=\"data row3 col5\" >0.73</td>\n",
       "      <td id=\"T_4e33d_row3_col6\" class=\"data row3 col6\" >0.65</td>\n",
       "      <td id=\"T_4e33d_row3_col7\" class=\"data row3 col7\" >0.73</td>\n",
       "      <td id=\"T_4e33d_row3_col8\" class=\"data row3 col8\" >0.71</td>\n",
       "      <td id=\"T_4e33d_row3_col9\" class=\"data row3 col9\" >0.73</td>\n",
       "      <td id=\"T_4e33d_row3_col10\" class=\"data row3 col10\" >0.80</td>\n",
       "      <td id=\"T_4e33d_row3_col11\" class=\"data row3 col11\" >0.72</td>\n",
       "      <td id=\"T_4e33d_row3_col12\" class=\"data row3 col12\" >0.69</td>\n",
       "      <td id=\"T_4e33d_row3_col13\" class=\"data row3 col13\" >0.69</td>\n",
       "      <td id=\"T_4e33d_row3_col14\" class=\"data row3 col14\" >0.76</td>\n",
       "      <td id=\"T_4e33d_row3_col15\" class=\"data row3 col15\" >0.74</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_4e33d_level0_row4\" class=\"row_heading level0 row4\" >strategy_qa</th>\n",
       "      <td id=\"T_4e33d_row4_col0\" class=\"data row4 col0\" >0.57</td>\n",
       "      <td id=\"T_4e33d_row4_col1\" class=\"data row4 col1\" >0.56</td>\n",
       "      <td id=\"T_4e33d_row4_col2\" class=\"data row4 col2\" >0.52</td>\n",
       "      <td id=\"T_4e33d_row4_col3\" class=\"data row4 col3\" >0.58</td>\n",
       "      <td id=\"T_4e33d_row4_col4\" class=\"data row4 col4\" >0.59</td>\n",
       "      <td id=\"T_4e33d_row4_col5\" class=\"data row4 col5\" >0.44</td>\n",
       "      <td id=\"T_4e33d_row4_col6\" class=\"data row4 col6\" >0.43</td>\n",
       "      <td id=\"T_4e33d_row4_col7\" class=\"data row4 col7\" >0.58</td>\n",
       "      <td id=\"T_4e33d_row4_col8\" class=\"data row4 col8\" >0.62</td>\n",
       "      <td id=\"T_4e33d_row4_col9\" class=\"data row4 col9\" >0.56</td>\n",
       "      <td id=\"T_4e33d_row4_col10\" class=\"data row4 col10\" >0.50</td>\n",
       "      <td id=\"T_4e33d_row4_col11\" class=\"data row4 col11\" >0.64</td>\n",
       "      <td id=\"T_4e33d_row4_col12\" class=\"data row4 col12\" >0.63</td>\n",
       "      <td id=\"T_4e33d_row4_col13\" class=\"data row4 col13\" >0.58</td>\n",
       "      <td id=\"T_4e33d_row4_col14\" class=\"data row4 col14\" >0.52</td>\n",
       "      <td id=\"T_4e33d_row4_col15\" class=\"data row4 col15\" >0.57</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_4e33d_level0_row5\" class=\"row_heading level0 row5\" >worldtree</th>\n",
       "      <td id=\"T_4e33d_row5_col0\" class=\"data row5 col0\" >0.96</td>\n",
       "      <td id=\"T_4e33d_row5_col1\" class=\"data row5 col1\" >0.93</td>\n",
       "      <td id=\"T_4e33d_row5_col2\" class=\"data row5 col2\" >0.95</td>\n",
       "      <td id=\"T_4e33d_row5_col3\" class=\"data row5 col3\" >0.95</td>\n",
       "      <td id=\"T_4e33d_row5_col4\" class=\"data row5 col4\" >0.92</td>\n",
       "      <td id=\"T_4e33d_row5_col5\" class=\"data row5 col5\" >0.95</td>\n",
       "      <td id=\"T_4e33d_row5_col6\" class=\"data row5 col6\" >0.74</td>\n",
       "      <td id=\"T_4e33d_row5_col7\" class=\"data row5 col7\" >0.92</td>\n",
       "      <td id=\"T_4e33d_row5_col8\" class=\"data row5 col8\" >0.91</td>\n",
       "      <td id=\"T_4e33d_row5_col9\" class=\"data row5 col9\" >0.95</td>\n",
       "      <td id=\"T_4e33d_row5_col10\" class=\"data row5 col10\" >0.92</td>\n",
       "      <td id=\"T_4e33d_row5_col11\" class=\"data row5 col11\" >0.96</td>\n",
       "      <td id=\"T_4e33d_row5_col12\" class=\"data row5 col12\" >0.91</td>\n",
       "      <td id=\"T_4e33d_row5_col13\" class=\"data row5 col13\" >0.92</td>\n",
       "      <td id=\"T_4e33d_row5_col14\" class=\"data row5 col14\" >0.96</td>\n",
       "      <td id=\"T_4e33d_row5_col15\" class=\"data row5 col15\" >0.96</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_4e33d_level0_row6\" class=\"row_heading level0 row6\" >Average</th>\n",
       "      <td id=\"T_4e33d_row6_col0\" class=\"data row6 col0\" >0.70</td>\n",
       "      <td id=\"T_4e33d_row6_col1\" class=\"data row6 col1\" >0.66</td>\n",
       "      <td id=\"T_4e33d_row6_col2\" class=\"data row6 col2\" >0.65</td>\n",
       "      <td id=\"T_4e33d_row6_col3\" class=\"data row6 col3\" >0.66</td>\n",
       "      <td id=\"T_4e33d_row6_col4\" class=\"data row6 col4\" >0.68</td>\n",
       "      <td id=\"T_4e33d_row6_col5\" class=\"data row6 col5\" >0.63</td>\n",
       "      <td id=\"T_4e33d_row6_col6\" class=\"data row6 col6\" >0.58</td>\n",
       "      <td id=\"T_4e33d_row6_col7\" class=\"data row6 col7\" >0.63</td>\n",
       "      <td id=\"T_4e33d_row6_col8\" class=\"data row6 col8\" >0.65</td>\n",
       "      <td id=\"T_4e33d_row6_col9\" class=\"data row6 col9\" >0.66</td>\n",
       "      <td id=\"T_4e33d_row6_col10\" class=\"data row6 col10\" >0.66</td>\n",
       "      <td id=\"T_4e33d_row6_col11\" class=\"data row6 col11\" >0.64</td>\n",
       "      <td id=\"T_4e33d_row6_col12\" class=\"data row6 col12\" >0.65</td>\n",
       "      <td id=\"T_4e33d_row6_col13\" class=\"data row6 col13\" >0.66</td>\n",
       "      <td id=\"T_4e33d_row6_col14\" class=\"data row6 col14\" >0.66</td>\n",
       "      <td id=\"T_4e33d_row6_col15\" class=\"data row6 col15\" >0.64</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<pandas.io.formats.style.Styler at 0x7f61546b11e0>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Evaluate the collection `ts_100` and render the result as a styled table\n",
    "# (one row per dataset plus an \"Average\" row, one column per evaluated configuration).\n",
    "# NOTE(review): `ts_100` must already exist in the kernel; confirm the cell that creates it runs before this one.\n",
    "# NOTE(review): `eval` shadows the Python builtin `eval`; consider renaming it after checking which later cells read it.\n",
    "eval = ts_100.evaluate()\n",
    "evaluation_as_table(eval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def last_cot_is_correct(item):\n",
    "    \"\"\"Return True if the item's most recent generated CoT got its first answer right.\n",
    "\n",
    "    Items without a usable CoT (missing or empty `generated_cots`, or a last CoT\n",
    "    with no `answers`) are treated as not correct and dropped, instead of raising\n",
    "    KeyError/IndexError inside `Dataset.filter`.\n",
    "    \"\"\"\n",
    "    cots = item.get(\"generated_cots\") or []\n",
    "    if not cots:\n",
    "        return False\n",
    "    answers = cots[-1].get(\"answers\") or []\n",
    "    if not answers:\n",
    "        return False\n",
    "    # Truthiness instead of `== True`: None (not yet graded) counts as not correct.\n",
    "    return bool(answers[0].get(\"correct_answer\"))\n",
    "\n",
    "\n",
    "# Keep only items whose last CoT answered correctly; later cells read the filtered `coll`.\n",
    "coll = coll.filter(last_cot_is_correct)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 121,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "| Name           | Train   | Valid   | Test   |\n",
       "|----------------|---------|---------|--------|\n",
       "| open_book_qa   | -       | -       | 21     |\n",
       "| strategy_qa    | 11      | -       | -      |\n",
       "| worldtree      | -       | -       | 7      |\n",
       "| commonsense_qa | -       | 12      | -      |\n",
       "| med_qa         | -       | -       | 20     |\n",
       "\n",
       "Not loaded: ['aqua', 'asdiv', 'entailment_bank', 'gsm8k', 'mawps', 'medmc_qa', 'pubmed_qa', 'qed', 'svamp']"
      ]
     },
     "execution_count": 121,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "coll"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "| Name           | Train   | Valid   | Test   |\n",
       "|----------------|---------|---------|--------|\n",
       "| open_book_qa   | -       | -       | 21     |\n",
       "| strategy_qa    | 11      | -       | -      |\n",
       "| worldtree      | -       | -       | 7      |\n",
       "| commonsense_qa | -       | 12      | -      |\n",
       "| med_qa         | -       | -       | 20     |\n",
       "\n",
       "Not loaded: ['aqua', 'asdiv', 'entailment_bank', 'gsm8k', 'mawps', 'medmc_qa', 'pubmed_qa', 'qed', 'svamp']"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "coll"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 119,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "| Name           | Train   | Valid   | Test   |\n",
       "|----------------|---------|---------|--------|\n",
       "| open_book_qa   | -       | -       | 21     |\n",
       "| strategy_qa    | 11      | -       | -      |\n",
       "| worldtree      | -       | -       | 7      |\n",
       "| commonsense_qa | -       | 0       | -      |\n",
       "| med_qa         | -       | -       | 0      |\n",
       "\n",
       "Not loaded: ['aqua', 'asdiv', 'entailment_bank', 'gsm8k', 'mawps', 'medmc_qa', 'pubmed_qa', 'qed', 'svamp']"
      ]
     },
     "execution_count": 119,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "coll.filter(lambda x: len(x[\"generated_cot\"][1][\"annotations\"]) > 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "metadata": {},
   "outputs": [],
   "source": [
    "coll.dump(\"ts_hard_v1_ts_cots\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "| Name           | Train   | Valid   | Test   |\n",
       "|----------------|---------|---------|--------|\n",
       "| open_book_qa   | -       | -       | 21     |\n",
       "| strategy_qa    | 11      | -       | -      |\n",
       "| worldtree      | -       | -       | 7      |\n",
       "| commonsense_qa | -       | 12      | -      |\n",
       "| med_qa         | -       | -       | 20     |\n",
       "\n",
       "Not loaded: ['aqua', 'asdiv', 'entailment_bank', 'gsm8k', 'mawps', 'medmc_qa', 'pubmed_qa', 'qed', 'svamp']"
      ]
     },
     "execution_count": 82,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "coll"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "| Name           | Train   | Valid   | Test   |\n",
       "|----------------|---------|---------|--------|\n",
       "| open_book_qa   | -       | -       | 21     |\n",
       "| strategy_qa    | 11      | -       | -      |\n",
       "| worldtree      | -       | -       | 7      |\n",
       "| commonsense_qa | -       | 12      | -      |\n",
       "| med_qa         | -       | -       | 20     |\n",
       "\n",
       "Not loaded: ['aqua', 'asdiv', 'entailment_bank', 'gsm8k', 'mawps', 'medmc_qa', 'pubmed_qa', 'qed', 'svamp']"
      ]
     },
     "execution_count": 89,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "coll"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [],
   "source": [
    "coll.dump(\"ts_hard_v3.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "coll.select_generated_cots()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "coll = coll.select(\"all\", number_samples=2, random_samples=True, seed=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration of the input and parameters of the language model \n",
    "config={\n",
    "    \"instruction_keys\": None,\n",
    "    \"cot_trigger_keys\": None,\n",
    "    \"answer_extraction_keys\": 'auto-kojima', \n",
    "    \"author\" : \"thoughtsource\",\n",
    "    \"api_service\": \"openai_chat\",\n",
    "    \"api_time_interval\": 10,\n",
    "    \"engine\": \"gpt-4\", \n",
    "    \"temperature\": 0,\n",
    "    \"max_tokens\": 512,\n",
    "    \"verbose\": False,\n",
    "    \"warn\": False,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "coll.generate(config=config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eval = coll.evaluate()\n",
    "evaluation_as_table(eval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "coll.dump(\"gpt-4_None\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "coll = Collection.load_thoughtsource_100(\"all\",load_pregenerated_cots=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration of the input and parameters of the language model \n",
    "config={\n",
    "    \"instruction_keys\": \"zhou-01-ins\",\n",
    "    \"cot_trigger_keys\": \"zhou-01\",\n",
    "    \"answer_extraction_keys\": 'auto-kojima', \n",
    "    \"author\" : \"thoughtsource\",\n",
    "    \"api_service\": \"openai_chat\",\n",
    "    \"engine\": \"gpt-3.5-turbo\", \n",
    "    \"temperature\": 0,\n",
    "    \"max_tokens\": 512,\n",
    "    \"verbose\": False,\n",
    "    \"warn\": False,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "coll.generate(config=config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eval = coll.evaluate()\n",
    "evaluation_as_table(eval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "coll.dump(\"thoughtsource_100\" + \"_\" + config['api_service'] + \"_\" + config['engine'].replace(\"/\", \"_\") + \"_\" + join_strings(config[\"instruction_keys\"]) + \"_\" + join_strings(config[\"cot_trigger_keys\"]) + \".json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# API error, did not work\n",
    "config={\n",
    "    \"instruction_keys\": \"qa-08\", # API error, did not work\n",
    "    \"cot_trigger_keys\": None,\n",
    "    \"answer_extraction_keys\": 'auto-kojima', \n",
    "    \"author\" : \"thoughtsource\",\n",
    "    \"api_service\": \"openai_chat\",\n",
    "    \"engine\": \"gpt-3.5-turbo\", \n",
    "    \"temperature\": 0,\n",
    "    \"max_tokens\": 512,\n",
    "    \"verbose\": False,\n",
    "    \"warn\": False,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "coll1 = Collection.from_json(\"qa-05.json\")\n",
    "coll2 = Collection.from_json(\"qa-06.json\")\n",
    "coll3 = Collection.from_json(\"qa-07.json\")\n",
    "collz = Collection.from_json(\"zhou-01_cot.json\")\n",
    "collzi = Collection.from_json(\"zhou-01_ins.json\")\n",
    "\n",
    "coll_merged = coll1.merge(coll2)\n",
    "coll_merged = coll_merged.merge(coll3)\n",
    "coll_merged = coll_merged.merge(collz)\n",
    "coll_merged = coll_merged.merge(collzi)\n",
    "coll_merged.dump(\"current\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eval = coll_merged.evaluate()\n",
    "evaluation_as_table(eval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# coll = Collection.from_json(\"ts_1_sel_gen.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# coll.select_generated_cots(model=[\"flan-T5-xxl\", \"text-davinci-003\", \"gpt-3.5-turbo\"])\n",
    "# coll.select_generated_cots(cot_trigger=[\"kojima-01\", \"zhou-01\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# coll.select_generated_cots(model=[\"flan-T5-xxl\", \"text-davinci-003\", \"gpt-3.5-turbo\"])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### merge into ts_100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# coll = Collection.from_json(\"/home/kon/work/ThoughtSource/notebooks/thoughtsource_100_openai_chat_gpt-3.5-turbo_zhou-01-ins_None.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[nltk_data] Downloading package punkt to /home/kon/nltk_data...\n",
      "[nltk_data]   Package punkt is already up-to-date!\n"
     ]
    }
   ],
   "source": [
    "ts_100 = Collection.load_thoughtsource_100()\n",
    "# ts_100 = Collection.from_json(\"thoughtsource_100.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_colls = [\n",
    "    \"/home/kon/work/ThoughtSource/notebooks/thoughtsource_100_openai_chat_gpt-3.5-turbo_None_kojima-03.json\",\n",
    "    \"/home/kon/work/ThoughtSource/notebooks/thoughtsource_100_openai_chat_gpt-3.5-turbo_None_kojima-09.json\",\n",
    "    \"/home/kon/work/ThoughtSource/notebooks/thoughtsource_100_openai_chat_gpt-3.5-turbo_qa-01_None.json\",\n",
    "    # \"/home/kon/work/ThoughtSource/notebooks/thoughtsource_100_openai_chat_gpt-3.5-turbo_qa-05_None.json\",\n",
    "    # \"/home/kon/work/ThoughtSource/notebooks/thoughtsource_100_openai_chat_gpt-3.5-turbo_zhou-01-ins_None.json\",\n",
    "    # \"/home/kon/work/ThoughtSource/notebooks/thoughtsource_100_openai_chat_gpt-3.5-turbo_zhou-01-ins_zhou-01.json\",\n",
    "    \"/home/kon/work/ThoughtSource/notebooks/thoughtsource_100_openai_chat_gpt-4_None_zhou-01.json\",\n",
    "    \"/home/kon/work/ThoughtSource/notebooks/thoughtsource_100_openai_chat_gpt-4_None_None.json\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for coll_path in new_colls:\n",
    "    coll = Collection.from_json(coll_path)\n",
    "    ts_100 = ts_100.merge(coll)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ts_100.dump(\"thoughtsource_100.json\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 170,
   "metadata": {},
   "outputs": [],
   "source": [
    "coll = Collection.load_thoughtsource_100()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 171,
   "metadata": {},
   "outputs": [],
   "source": [
    "# coll.select_generated_cots(author='thoughtsource', model=[\"gpt-3.5-turbo\", \"flan-T5-xxl\"], cot_trigger=None, instruction=None)\n",
    "coll.select_generated_cots(\n",
    "    author=\n",
    "        'thoughtsource',\n",
    "    model=[\n",
    "        # 'gpt-3.5-turbo',\n",
    "        'gpt-4',\n",
    "        # 'flan-T5-xxl',\n",
    "    ],\n",
    "    cot_trigger=[\n",
    "        # None,\n",
    "        'zhou-01'\n",
    "    ],\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 172,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "| Name           | Train   | Valid   | Test   |\n",
       "|----------------|---------|---------|--------|\n",
       "| commonsense_qa | -       | 28      | -      |\n",
       "| med_qa         | -       | -       | 24     |\n",
       "| medmc_qa       | -       | 30      | -      |\n",
       "| open_book_qa   | -       | -       | 5      |\n",
       "| strategy_qa    | 20      | -       | -      |\n",
       "| worldtree      | -       | -       | 1      |\n",
       "\n",
       "Not loaded: ['aqua', 'asdiv', 'entailment_bank', 'gsm8k', 'mawps', 'pubmed_qa', 'qed', 'svamp']"
      ]
     },
     "execution_count": 172,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "coll.select_generated_cots(answer=False)\n",
    "coll = coll.filter(lambda x: len(x[\"generated_cot\"])==1)\n",
    "coll"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 173,
   "metadata": {},
   "outputs": [],
   "source": [
    "coll1 = coll.copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 174,
   "metadata": {},
   "outputs": [],
   "source": [
    "coll = Collection.load_thoughtsource_100()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 175,
   "metadata": {},
   "outputs": [],
   "source": [
    "coll.select_generated_cots(\n",
    "    author=\n",
    "        'thoughtsource',\n",
    "    model=[\n",
    "        # 'gpt-3.5-turbo',\n",
    "        'gpt-4',\n",
    "        # 'flan-T5-xxl',\n",
    "    ],\n",
    "    cot_trigger=[\n",
    "        None,\n",
    "        # 'zhou-01'\n",
    "    ],\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 176,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "| Name           | Train   | Valid   | Test   |\n",
       "|----------------|---------|---------|--------|\n",
       "| commonsense_qa | -       | 75      | -      |\n",
       "| med_qa         | -       | -       | 73     |\n",
       "| medmc_qa       | -       | 69      | -      |\n",
       "| open_book_qa   | -       | -       | 92     |\n",
       "| strategy_qa    | 71      | -       | -      |\n",
       "| worldtree      | -       | -       | 99     |\n",
       "\n",
       "Not loaded: ['aqua', 'asdiv', 'entailment_bank', 'gsm8k', 'mawps', 'pubmed_qa', 'qed', 'svamp']"
      ]
     },
     "execution_count": 176,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "coll.select_generated_cots(answer=True)\n",
    "coll = coll.filter(lambda x: len(x[\"generated_cot\"])==1)\n",
    "coll"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 177,
   "metadata": {},
   "outputs": [],
   "source": [
    "coll2 = coll.copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 178,
   "metadata": {},
   "outputs": [],
   "source": [
    "coll = coll1.merge(coll2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 179,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e33c96c7c3814b10ae99a8a4abc6babc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e7426828714045d49041be66544f78a4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "aef1ebd75784468ca812f45ed87d3cb6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6632aa4c4c2c49b28b95733e56af87c0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "124319460b7e49458cdefbc414f2cfdd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cc2f3230542845c98e39bf04e5db14fd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "coll = coll.filter(lambda x: len(x[\"generated_cot\"])==2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 180,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "| Name           | Train   | Valid   | Test   |\n",
       "|----------------|---------|---------|--------|\n",
       "| commonsense_qa | -       | 7       | -      |\n",
       "| med_qa         | -       | -       | 6      |\n",
       "| medmc_qa       | -       | 8       | -      |\n",
       "| open_book_qa   | -       | -       | 2      |\n",
       "| strategy_qa    | 3       | -       | -      |\n",
       "| worldtree      | -       | -       | 0      |\n",
       "\n",
       "Not loaded: ['aqua', 'asdiv', 'entailment_bank', 'gsm8k', 'mawps', 'pubmed_qa', 'qed', 'svamp']"
      ]
     },
     "execution_count": 180,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "coll"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 181,
   "metadata": {},
   "outputs": [],
   "source": [
    "coll.unload_datasets([\"worldtree\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 182,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "| Name           | Train   | Valid   | Test   |\n",
       "|----------------|---------|---------|--------|\n",
       "| commonsense_qa | -       | 7       | -      |\n",
       "| med_qa         | -       | -       | 6      |\n",
       "| medmc_qa       | -       | 8       | -      |\n",
       "| open_book_qa   | -       | -       | 2      |\n",
       "| strategy_qa    | 3       | -       | -      |\n",
       "\n",
       "Not loaded: ['aqua', 'asdiv', 'entailment_bank', 'gsm8k', 'mawps', 'pubmed_qa', 'qed', 'svamp', 'worldtree']"
      ]
     },
     "execution_count": 182,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "coll"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 183,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "61ab7b3172224a9191b54cc737982727",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/7 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b92acf8d817c4a44ae71ae59504d1893",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/6 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f3214aa794f44631a09c405f1f38e0c8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/8 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "09bdfb1489a44c9d93b23699491c6a74",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/2 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8dae271ff71544b49e391d5b32485380",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/3 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<style type=\"text/css\">\n",
       "#T_b4b43_row0_col0, #T_b4b43_row1_col0, #T_b4b43_row2_col0, #T_b4b43_row3_col0, #T_b4b43_row4_col0, #T_b4b43_row5_col0 {\n",
       "  font-weight: bold;\n",
       "}\n",
       "</style>\n",
       "<table id=\"T_b4b43\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th class=\"blank level0\" >&nbsp;</th>\n",
       "      <th id=\"T_b4b43_level0_col0\" class=\"col_heading level0 col0\" >None_None</th>\n",
       "      <th id=\"T_b4b43_level0_col1\" class=\"col_heading level0 col1\" >None_zhou-01</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th class=\"blank level1\" >&nbsp;</th>\n",
       "      <th id=\"T_b4b43_level1_col0\" class=\"col_heading level1 col0\" >gpt-4</th>\n",
       "      <th id=\"T_b4b43_level1_col1\" class=\"col_heading level1 col1\" >gpt-4</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th id=\"T_b4b43_level0_row0\" class=\"row_heading level0 row0\" >commonsense_qa</th>\n",
       "      <td id=\"T_b4b43_row0_col0\" class=\"data row0 col0\" >1.00</td>\n",
       "      <td id=\"T_b4b43_row0_col1\" class=\"data row0 col1\" >0.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_b4b43_level0_row1\" class=\"row_heading level0 row1\" >med_qa</th>\n",
       "      <td id=\"T_b4b43_row1_col0\" class=\"data row1 col0\" >1.00</td>\n",
       "      <td id=\"T_b4b43_row1_col1\" class=\"data row1 col1\" >0.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_b4b43_level0_row2\" class=\"row_heading level0 row2\" >medmc_qa</th>\n",
       "      <td id=\"T_b4b43_row2_col0\" class=\"data row2 col0\" >1.00</td>\n",
       "      <td id=\"T_b4b43_row2_col1\" class=\"data row2 col1\" >0.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_b4b43_level0_row3\" class=\"row_heading level0 row3\" >open_book_qa</th>\n",
       "      <td id=\"T_b4b43_row3_col0\" class=\"data row3 col0\" >1.00</td>\n",
       "      <td id=\"T_b4b43_row3_col1\" class=\"data row3 col1\" >0.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_b4b43_level0_row4\" class=\"row_heading level0 row4\" >strategy_qa</th>\n",
       "      <td id=\"T_b4b43_row4_col0\" class=\"data row4 col0\" >1.00</td>\n",
       "      <td id=\"T_b4b43_row4_col1\" class=\"data row4 col1\" >0.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_b4b43_level0_row5\" class=\"row_heading level0 row5\" >Average</th>\n",
       "      <td id=\"T_b4b43_row5_col0\" class=\"data row5 col0\" >1.00</td>\n",
       "      <td id=\"T_b4b43_row5_col1\" class=\"data row5 col1\" >0.00</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<pandas.io.formats.style.Styler at 0x7f9110949f30>"
      ]
     },
     "execution_count": 183,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "eval = coll.evaluate()\n",
    "table = evaluation_as_table(eval)\n",
    "table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 184,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0599f84b21964aa88f373b61d016cbe9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1fec405209a5444ab7708c353a85fdc4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "692f0d8421e74d7092c3161cd9905498",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "31c754743c2e48e9967580ffe7c0ec74",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9dc91b138ad74bb29731bbff036690da",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "coll.dump(\"filtered_ts_100_gpt-4_None_correct_zhou-01_false.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'coll' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[1], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[39meval\u001b[39m \u001b[39m=\u001b[39m coll\u001b[39m.\u001b[39mevaluate()\n\u001b[1;32m      2\u001b[0m table1 \u001b[39m=\u001b[39m evaluation_as_table(\u001b[39meval\u001b[39m)\n\u001b[1;32m      3\u001b[0m table1\n",
      "\u001b[0;31mNameError\u001b[0m: name 'coll' is not defined"
     ]
    }
   ],
   "source": [
    "eval = coll.evaluate()\n",
    "table1 = evaluation_as_table(eval)\n",
    "table1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style type=\"text/css\">\n",
       "#T_beda1_row0_col1, #T_beda1_row1_col18, #T_beda1_row2_col18, #T_beda1_row3_col18, #T_beda1_row4_col18, #T_beda1_row5_col3, #T_beda1_row6_col18 {\n",
       "  font-weight: bold;\n",
       "}\n",
       "</style>\n",
       "<table id=\"T_beda1\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th class=\"blank level0\" >&nbsp;</th>\n",
       "      <th id=\"T_beda1_level0_col0\" class=\"col_heading level0 col0\" colspan=\"6\">None_None</th>\n",
       "      <th id=\"T_beda1_level0_col6\" class=\"col_heading level0 col6\" colspan=\"5\">None_kojima-01</th>\n",
       "      <th id=\"T_beda1_level0_col11\" class=\"col_heading level0 col11\" >None_kojima-03</th>\n",
       "      <th id=\"T_beda1_level0_col12\" class=\"col_heading level0 col12\" >None_kojima-09</th>\n",
       "      <th id=\"T_beda1_level0_col13\" class=\"col_heading level0 col13\" >None_lievin-01</th>\n",
       "      <th id=\"T_beda1_level0_col14\" class=\"col_heading level0 col14\" >None_lievin-02</th>\n",
       "      <th id=\"T_beda1_level0_col15\" class=\"col_heading level0 col15\" >None_lievin-03</th>\n",
       "      <th id=\"T_beda1_level0_col16\" class=\"col_heading level0 col16\" >None_lievin-10</th>\n",
       "      <th id=\"T_beda1_level0_col17\" class=\"col_heading level0 col17\" colspan=\"2\">None_zhou-01</th>\n",
       "      <th id=\"T_beda1_level0_col19\" class=\"col_heading level0 col19\" >qa-01_None</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th class=\"blank level1\" >&nbsp;</th>\n",
       "      <th id=\"T_beda1_level1_col0\" class=\"col_heading level1 col0\" >command-xlarge-nightly</th>\n",
       "      <th id=\"T_beda1_level1_col1\" class=\"col_heading level1 col1\" >flan-T5-xxl</th>\n",
       "      <th id=\"T_beda1_level1_col2\" class=\"col_heading level1 col2\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_beda1_level1_col3\" class=\"col_heading level1 col3\" >gpt-4</th>\n",
       "      <th id=\"T_beda1_level1_col4\" class=\"col_heading level1 col4\" >text-davinci-002</th>\n",
       "      <th id=\"T_beda1_level1_col5\" class=\"col_heading level1 col5\" >text-davinci-003</th>\n",
       "      <th id=\"T_beda1_level1_col6\" class=\"col_heading level1 col6\" >command-xlarge-nightly</th>\n",
       "      <th id=\"T_beda1_level1_col7\" class=\"col_heading level1 col7\" >flan-T5-xxl</th>\n",
       "      <th id=\"T_beda1_level1_col8\" class=\"col_heading level1 col8\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_beda1_level1_col9\" class=\"col_heading level1 col9\" >text-davinci-002</th>\n",
       "      <th id=\"T_beda1_level1_col10\" class=\"col_heading level1 col10\" >text-davinci-003</th>\n",
       "      <th id=\"T_beda1_level1_col11\" class=\"col_heading level1 col11\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_beda1_level1_col12\" class=\"col_heading level1 col12\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_beda1_level1_col13\" class=\"col_heading level1 col13\" >text-davinci-002</th>\n",
       "      <th id=\"T_beda1_level1_col14\" class=\"col_heading level1 col14\" >text-davinci-002</th>\n",
       "      <th id=\"T_beda1_level1_col15\" class=\"col_heading level1 col15\" >text-davinci-002</th>\n",
       "      <th id=\"T_beda1_level1_col16\" class=\"col_heading level1 col16\" >text-davinci-002</th>\n",
       "      <th id=\"T_beda1_level1_col17\" class=\"col_heading level1 col17\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_beda1_level1_col18\" class=\"col_heading level1 col18\" >gpt-4</th>\n",
       "      <th id=\"T_beda1_level1_col19\" class=\"col_heading level1 col19\" >gpt-3.5-turbo</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th id=\"T_beda1_level0_row0\" class=\"row_heading level0 row0\" >commonsense_qa</th>\n",
       "      <td id=\"T_beda1_row0_col0\" class=\"data row0 col0\" >0.51</td>\n",
       "      <td id=\"T_beda1_row0_col1\" class=\"data row0 col1\" >0.87</td>\n",
       "      <td id=\"T_beda1_row0_col2\" class=\"data row0 col2\" >0.72</td>\n",
       "      <td id=\"T_beda1_row0_col3\" class=\"data row0 col3\" >0.75</td>\n",
       "      <td id=\"T_beda1_row0_col4\" class=\"data row0 col4\" >0.69</td>\n",
       "      <td id=\"T_beda1_row0_col5\" class=\"data row0 col5\" >0.72</td>\n",
       "      <td id=\"T_beda1_row0_col6\" class=\"data row0 col6\" >0.53</td>\n",
       "      <td id=\"T_beda1_row0_col7\" class=\"data row0 col7\" >0.81</td>\n",
       "      <td id=\"T_beda1_row0_col8\" class=\"data row0 col8\" >0.67</td>\n",
       "      <td id=\"T_beda1_row0_col9\" class=\"data row0 col9\" >0.66</td>\n",
       "      <td id=\"T_beda1_row0_col10\" class=\"data row0 col10\" >0.65</td>\n",
       "      <td id=\"T_beda1_row0_col11\" class=\"data row0 col11\" >0.63</td>\n",
       "      <td id=\"T_beda1_row0_col12\" class=\"data row0 col12\" >0.70</td>\n",
       "      <td id=\"T_beda1_row0_col13\" class=\"data row0 col13\" >nan</td>\n",
       "      <td id=\"T_beda1_row0_col14\" class=\"data row0 col14\" >nan</td>\n",
       "      <td id=\"T_beda1_row0_col15\" class=\"data row0 col15\" >nan</td>\n",
       "      <td id=\"T_beda1_row0_col16\" class=\"data row0 col16\" >nan</td>\n",
       "      <td id=\"T_beda1_row0_col17\" class=\"data row0 col17\" >0.66</td>\n",
       "      <td id=\"T_beda1_row0_col18\" class=\"data row0 col18\" >0.72</td>\n",
       "      <td id=\"T_beda1_row0_col19\" class=\"data row0 col19\" >0.66</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_beda1_level0_row1\" class=\"row_heading level0 row1\" >med_qa</th>\n",
       "      <td id=\"T_beda1_row1_col0\" class=\"data row1 col0\" >0.23</td>\n",
       "      <td id=\"T_beda1_row1_col1\" class=\"data row1 col1\" >0.32</td>\n",
       "      <td id=\"T_beda1_row1_col2\" class=\"data row1 col2\" >0.58</td>\n",
       "      <td id=\"T_beda1_row1_col3\" class=\"data row1 col3\" >0.73</td>\n",
       "      <td id=\"T_beda1_row1_col4\" class=\"data row1 col4\" >0.41</td>\n",
       "      <td id=\"T_beda1_row1_col5\" class=\"data row1 col5\" >0.43</td>\n",
       "      <td id=\"T_beda1_row1_col6\" class=\"data row1 col6\" >0.31</td>\n",
       "      <td id=\"T_beda1_row1_col7\" class=\"data row1 col7\" >0.34</td>\n",
       "      <td id=\"T_beda1_row1_col8\" class=\"data row1 col8\" >0.59</td>\n",
       "      <td id=\"T_beda1_row1_col9\" class=\"data row1 col9\" >0.34</td>\n",
       "      <td id=\"T_beda1_row1_col10\" class=\"data row1 col10\" >0.43</td>\n",
       "      <td id=\"T_beda1_row1_col11\" class=\"data row1 col11\" >0.59</td>\n",
       "      <td id=\"T_beda1_row1_col12\" class=\"data row1 col12\" >0.51</td>\n",
       "      <td id=\"T_beda1_row1_col13\" class=\"data row1 col13\" >0.44</td>\n",
       "      <td id=\"T_beda1_row1_col14\" class=\"data row1 col14\" >0.50</td>\n",
       "      <td id=\"T_beda1_row1_col15\" class=\"data row1 col15\" >0.49</td>\n",
       "      <td id=\"T_beda1_row1_col16\" class=\"data row1 col16\" >0.45</td>\n",
       "      <td id=\"T_beda1_row1_col17\" class=\"data row1 col17\" >0.65</td>\n",
       "      <td id=\"T_beda1_row1_col18\" class=\"data row1 col18\" >0.76</td>\n",
       "      <td id=\"T_beda1_row1_col19\" class=\"data row1 col19\" >0.54</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_beda1_level0_row2\" class=\"row_heading level0 row2\" >medmc_qa</th>\n",
       "      <td id=\"T_beda1_row2_col0\" class=\"data row2 col0\" >0.25</td>\n",
       "      <td id=\"T_beda1_row2_col1\" class=\"data row2 col1\" >0.34</td>\n",
       "      <td id=\"T_beda1_row2_col2\" class=\"data row2 col2\" >0.58</td>\n",
       "      <td id=\"T_beda1_row2_col3\" class=\"data row2 col3\" >0.69</td>\n",
       "      <td id=\"T_beda1_row2_col4\" class=\"data row2 col4\" >0.34</td>\n",
       "      <td id=\"T_beda1_row2_col5\" class=\"data row2 col5\" >0.40</td>\n",
       "      <td id=\"T_beda1_row2_col6\" class=\"data row2 col6\" >0.22</td>\n",
       "      <td id=\"T_beda1_row2_col7\" class=\"data row2 col7\" >0.35</td>\n",
       "      <td id=\"T_beda1_row2_col8\" class=\"data row2 col8\" >0.47</td>\n",
       "      <td id=\"T_beda1_row2_col9\" class=\"data row2 col9\" >0.36</td>\n",
       "      <td id=\"T_beda1_row2_col10\" class=\"data row2 col10\" >0.36</td>\n",
       "      <td id=\"T_beda1_row2_col11\" class=\"data row2 col11\" >0.50</td>\n",
       "      <td id=\"T_beda1_row2_col12\" class=\"data row2 col12\" >0.50</td>\n",
       "      <td id=\"T_beda1_row2_col13\" class=\"data row2 col13\" >0.45</td>\n",
       "      <td id=\"T_beda1_row2_col14\" class=\"data row2 col14\" >0.41</td>\n",
       "      <td id=\"T_beda1_row2_col15\" class=\"data row2 col15\" >0.40</td>\n",
       "      <td id=\"T_beda1_row2_col16\" class=\"data row2 col16\" >0.42</td>\n",
       "      <td id=\"T_beda1_row2_col17\" class=\"data row2 col17\" >0.48</td>\n",
       "      <td id=\"T_beda1_row2_col18\" class=\"data row2 col18\" >0.70</td>\n",
       "      <td id=\"T_beda1_row2_col19\" class=\"data row2 col19\" >0.47</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_beda1_level0_row3\" class=\"row_heading level0 row3\" >open_book_qa</th>\n",
       "      <td id=\"T_beda1_row3_col0\" class=\"data row3 col0\" >0.59</td>\n",
       "      <td id=\"T_beda1_row3_col1\" class=\"data row3 col1\" >0.82</td>\n",
       "      <td id=\"T_beda1_row3_col2\" class=\"data row3 col2\" >0.77</td>\n",
       "      <td id=\"T_beda1_row3_col3\" class=\"data row3 col3\" >0.92</td>\n",
       "      <td id=\"T_beda1_row3_col4\" class=\"data row3 col4\" >0.67</td>\n",
       "      <td id=\"T_beda1_row3_col5\" class=\"data row3 col5\" >0.70</td>\n",
       "      <td id=\"T_beda1_row3_col6\" class=\"data row3 col6\" >0.38</td>\n",
       "      <td id=\"T_beda1_row3_col7\" class=\"data row3 col7\" >0.79</td>\n",
       "      <td id=\"T_beda1_row3_col8\" class=\"data row3 col8\" >0.77</td>\n",
       "      <td id=\"T_beda1_row3_col9\" class=\"data row3 col9\" >0.57</td>\n",
       "      <td id=\"T_beda1_row3_col10\" class=\"data row3 col10\" >0.67</td>\n",
       "      <td id=\"T_beda1_row3_col11\" class=\"data row3 col11\" >0.73</td>\n",
       "      <td id=\"T_beda1_row3_col12\" class=\"data row3 col12\" >0.73</td>\n",
       "      <td id=\"T_beda1_row3_col13\" class=\"data row3 col13\" >nan</td>\n",
       "      <td id=\"T_beda1_row3_col14\" class=\"data row3 col14\" >nan</td>\n",
       "      <td id=\"T_beda1_row3_col15\" class=\"data row3 col15\" >nan</td>\n",
       "      <td id=\"T_beda1_row3_col16\" class=\"data row3 col16\" >nan</td>\n",
       "      <td id=\"T_beda1_row3_col17\" class=\"data row3 col17\" >0.81</td>\n",
       "      <td id=\"T_beda1_row3_col18\" class=\"data row3 col18\" >0.95</td>\n",
       "      <td id=\"T_beda1_row3_col19\" class=\"data row3 col19\" >0.73</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_beda1_level0_row4\" class=\"row_heading level0 row4\" >strategy_qa</th>\n",
       "      <td id=\"T_beda1_row4_col0\" class=\"data row4 col0\" >0.52</td>\n",
       "      <td id=\"T_beda1_row4_col1\" class=\"data row4 col1\" >0.61</td>\n",
       "      <td id=\"T_beda1_row4_col2\" class=\"data row4 col2\" >0.57</td>\n",
       "      <td id=\"T_beda1_row4_col3\" class=\"data row4 col3\" >0.71</td>\n",
       "      <td id=\"T_beda1_row4_col4\" class=\"data row4 col4\" >0.59</td>\n",
       "      <td id=\"T_beda1_row4_col5\" class=\"data row4 col5\" >0.53</td>\n",
       "      <td id=\"T_beda1_row4_col6\" class=\"data row4 col6\" >0.59</td>\n",
       "      <td id=\"T_beda1_row4_col7\" class=\"data row4 col7\" >0.69</td>\n",
       "      <td id=\"T_beda1_row4_col8\" class=\"data row4 col8\" >0.56</td>\n",
       "      <td id=\"T_beda1_row4_col9\" class=\"data row4 col9\" >0.46</td>\n",
       "      <td id=\"T_beda1_row4_col10\" class=\"data row4 col10\" >0.54</td>\n",
       "      <td id=\"T_beda1_row4_col11\" class=\"data row4 col11\" >0.52</td>\n",
       "      <td id=\"T_beda1_row4_col12\" class=\"data row4 col12\" >0.58</td>\n",
       "      <td id=\"T_beda1_row4_col13\" class=\"data row4 col13\" >nan</td>\n",
       "      <td id=\"T_beda1_row4_col14\" class=\"data row4 col14\" >nan</td>\n",
       "      <td id=\"T_beda1_row4_col15\" class=\"data row4 col15\" >nan</td>\n",
       "      <td id=\"T_beda1_row4_col16\" class=\"data row4 col16\" >nan</td>\n",
       "      <td id=\"T_beda1_row4_col17\" class=\"data row4 col17\" >0.59</td>\n",
       "      <td id=\"T_beda1_row4_col18\" class=\"data row4 col18\" >0.80</td>\n",
       "      <td id=\"T_beda1_row4_col19\" class=\"data row4 col19\" >0.44</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_beda1_level0_row5\" class=\"row_heading level0 row5\" >worldtree</th>\n",
       "      <td id=\"T_beda1_row5_col0\" class=\"data row5 col0\" >0.63</td>\n",
       "      <td id=\"T_beda1_row5_col1\" class=\"data row5 col1\" >0.85</td>\n",
       "      <td id=\"T_beda1_row5_col2\" class=\"data row5 col2\" >0.96</td>\n",
       "      <td id=\"T_beda1_row5_col3\" class=\"data row5 col3\" >0.99</td>\n",
       "      <td id=\"T_beda1_row5_col4\" class=\"data row5 col4\" >0.88</td>\n",
       "      <td id=\"T_beda1_row5_col5\" class=\"data row5 col5\" >0.91</td>\n",
       "      <td id=\"T_beda1_row5_col6\" class=\"data row5 col6\" >0.58</td>\n",
       "      <td id=\"T_beda1_row5_col7\" class=\"data row5 col7\" >0.80</td>\n",
       "      <td id=\"T_beda1_row5_col8\" class=\"data row5 col8\" >0.93</td>\n",
       "      <td id=\"T_beda1_row5_col9\" class=\"data row5 col9\" >0.78</td>\n",
       "      <td id=\"T_beda1_row5_col10\" class=\"data row5 col10\" >0.89</td>\n",
       "      <td id=\"T_beda1_row5_col11\" class=\"data row5 col11\" >0.95</td>\n",
       "      <td id=\"T_beda1_row5_col12\" class=\"data row5 col12\" >0.95</td>\n",
       "      <td id=\"T_beda1_row5_col13\" class=\"data row5 col13\" >nan</td>\n",
       "      <td id=\"T_beda1_row5_col14\" class=\"data row5 col14\" >nan</td>\n",
       "      <td id=\"T_beda1_row5_col15\" class=\"data row5 col15\" >nan</td>\n",
       "      <td id=\"T_beda1_row5_col16\" class=\"data row5 col16\" >nan</td>\n",
       "      <td id=\"T_beda1_row5_col17\" class=\"data row5 col17\" >0.92</td>\n",
       "      <td id=\"T_beda1_row5_col18\" class=\"data row5 col18\" >0.99</td>\n",
       "      <td id=\"T_beda1_row5_col19\" class=\"data row5 col19\" >0.95</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_beda1_level0_row6\" class=\"row_heading level0 row6\" >Average</th>\n",
       "      <td id=\"T_beda1_row6_col0\" class=\"data row6 col0\" >0.46</td>\n",
       "      <td id=\"T_beda1_row6_col1\" class=\"data row6 col1\" >0.64</td>\n",
       "      <td id=\"T_beda1_row6_col2\" class=\"data row6 col2\" >0.70</td>\n",
       "      <td id=\"T_beda1_row6_col3\" class=\"data row6 col3\" >0.80</td>\n",
       "      <td id=\"T_beda1_row6_col4\" class=\"data row6 col4\" >0.60</td>\n",
       "      <td id=\"T_beda1_row6_col5\" class=\"data row6 col5\" >0.62</td>\n",
       "      <td id=\"T_beda1_row6_col6\" class=\"data row6 col6\" >0.44</td>\n",
       "      <td id=\"T_beda1_row6_col7\" class=\"data row6 col7\" >0.63</td>\n",
       "      <td id=\"T_beda1_row6_col8\" class=\"data row6 col8\" >0.66</td>\n",
       "      <td id=\"T_beda1_row6_col9\" class=\"data row6 col9\" >0.53</td>\n",
       "      <td id=\"T_beda1_row6_col10\" class=\"data row6 col10\" >0.59</td>\n",
       "      <td id=\"T_beda1_row6_col11\" class=\"data row6 col11\" >0.65</td>\n",
       "      <td id=\"T_beda1_row6_col12\" class=\"data row6 col12\" >0.66</td>\n",
       "      <td id=\"T_beda1_row6_col13\" class=\"data row6 col13\" >0.44</td>\n",
       "      <td id=\"T_beda1_row6_col14\" class=\"data row6 col14\" >0.45</td>\n",
       "      <td id=\"T_beda1_row6_col15\" class=\"data row6 col15\" >0.44</td>\n",
       "      <td id=\"T_beda1_row6_col16\" class=\"data row6 col16\" >0.44</td>\n",
       "      <td id=\"T_beda1_row6_col17\" class=\"data row6 col17\" >0.68</td>\n",
       "      <td id=\"T_beda1_row6_col18\" class=\"data row6 col18\" >0.82</td>\n",
       "      <td id=\"T_beda1_row6_col19\" class=\"data row6 col19\" >0.63</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<pandas.io.formats.style.Styler at 0x7fc15c331cf0>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "coll"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "coll.dump(\"chat-gpt_vs_gpt4_vs_t5_zhou-and-none_cot.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "coll2 = Collection.from_json(\"gpt-4_None.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "coll = coll.merge(coll2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "coll_1 = Collection.from_json(\"gpt-4_zhou-01.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "coll = coll.merge(coll_1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eval = coll.evaluate()\n",
    "evaluation_as_table(eval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "coll.dump(\"gpt-3.5_None_and_zhou-01_vs_gpt-4_None_and_zhou-01.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "coll = coll.select(\"all\", number_samples=20, random_samples=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eval = coll.evaluate()\n",
    "evaluation_as_table(eval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "coll.dump(\"ts_20_random.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eval = coll.evaluate()\n",
    "evaluation_as_table(eval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "coll.select_generated_cots(author='thoughtsource', model=[\"gpt-3.5-turbo\", \"flan-T5-xxl\"], cot_trigger=\"kojima-01\", instruction=None)\n",
    "# coll.select_generated_cots(author='thoughtsource')\n",
    "# coll.select_generated_cots(model='gpt-3.5-turbo')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eval = coll.evaluate()\n",
    "evaluation_as_table(eval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "coll.select_generated_cots(model='gpt-3.5-turbo')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eval = coll.evaluate()\n",
    "evaluation_as_table(eval)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### thoughtsource_1 dataset selection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# coll = coll.select(split=\"all\", number_samples=1, random_samples=True, seed=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# coll.dump(\"thoughtsource_1_selected.json\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
